{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "matchzoo version 2.1.0\n",
      "\n",
      "data loading ...\n",
      "data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`\n",
      "`ranking_task` initialized with metrics [normalized_discounted_cumulative_gain@3(0.0), normalized_discounted_cumulative_gain@5(0.0), mean_average_precision(0.0)]\n",
      "loading embedding ...\n",
      "embedding loaded as `glove_embedding`\n"
     ]
    }
   ],
   "source": [
    "%run init.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `text_left` is fixed to length 10 and `text_right` to length 100; stop words are kept.\n",
    "preprocessor = mz.preprocessors.BasicPreprocessor(\n",
    "    fixed_length_left=10,\n",
    "    fixed_length_right=100,\n",
    "    remove_stop_words=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 9365.13it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:03<00:00, 5163.33it/s]\n",
      "Processing text_right with append: 100%|██████████| 18841/18841 [00:00<00:00, 848689.58it/s]\n",
      "Building FrequencyFilter from a datapack.: 100%|██████████| 18841/18841 [00:00<00:00, 151196.34it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 89439.89it/s]\n",
      "Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 706219.56it/s]\n",
      "Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 736492.25it/s]\n",
      "Building Vocabulary from a datapack.: 100%|██████████| 404432/404432 [00:00<00:00, 3032613.85it/s]\n",
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 9550.46it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:03<00:00, 5185.71it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 144261.80it/s]\n",
      "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 239784.49it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 131422.51it/s]\n",
      "Processing length_left with len: 100%|██████████| 2118/2118 [00:00<00:00, 620662.05it/s]\n",
      "Processing length_right with len: 100%|██████████| 18841/18841 [00:00<00:00, 771237.80it/s]\n",
      "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 140320.27it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 95088.12it/s]\n",
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 122/122 [00:00<00:00, 9513.19it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 1115/1115 [00:00<00:00, 5119.33it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 10854.85it/s]\n",
      "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 113134.00it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 129446.66it/s]\n",
      "Processing length_left with len: 100%|██████████| 122/122 [00:00<00:00, 185736.87it/s]\n",
      "Processing length_right with len: 100%|██████████| 1115/1115 [00:00<00:00, 583778.42it/s]\n",
      "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 90215.99it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 93243.92it/s]\n",
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 237/237 [00:00<00:00, 8866.58it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2300/2300 [00:00<00:00, 5131.75it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 147786.31it/s]\n",
      "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 119176.36it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 111582.90it/s]\n",
      "Processing length_left with len: 100%|██████████| 237/237 [00:00<00:00, 354132.54it/s]\n",
      "Processing length_right with len: 100%|██████████| 2300/2300 [00:00<00:00, 543579.15it/s]\n",
      "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 103730.57it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 89680.20it/s]\n"
     ]
    }
   ],
   "source": [
    "# Fit on the training pack only, then apply the fitted preprocessor to dev and test (in that order).\n",
    "train_pack_processed = preprocessor.fit_transform(train_pack_raw)\n",
    "dev_pack_processed, test_pack_processed = (\n",
    "    preprocessor.transform(raw_pack) for raw_pack in (dev_pack_raw, test_pack_raw)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'filter_unit': <matchzoo.preprocessors.units.frequency_filter.FrequencyFilter at 0x7efc21e7ce48>,\n",
       " 'vocab_unit': <matchzoo.preprocessors.units.vocabulary.Vocabulary at 0x7efc16b9ec88>,\n",
       " 'vocab_size': 16674,\n",
       " 'embedding_input_dim': 16674,\n",
       " 'input_shapes': [(10,), (100,)]}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the context gathered while fitting: fitted units, vocabulary size and input shapes.\n",
    "preprocessor.context"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model_class                   <class 'matchzoo.models.arcii.ArcII'>\n",
      "input_shapes                  [(10,), (100,)]\n",
      "task                          Ranking Task\n",
      "optimizer                     adam\n",
      "with_embedding                True\n",
      "embedding_input_dim           16674\n",
      "embedding_output_dim          100\n",
      "embedding_trainable           True\n",
      "num_blocks                    2\n",
      "kernel_1d_count               32\n",
      "kernel_1d_size                3\n",
      "kernel_2d_count               [64, 64]\n",
      "kernel_2d_size                [3, 3]\n",
      "activation                    relu\n",
      "pool_2d_size                  [[3, 3], [3, 3]]\n",
      "padding                       same\n",
      "dropout_rate                  0.0\n"
     ]
    }
   ],
   "source": [
    "model = mz.models.ArcII()\n",
    "\n",
    "# load `input_shapes` and `embedding_input_dim` (vocab_size)\n",
    "model.params.update(preprocessor.context)\n",
    "\n",
    "# Task, embedding, convolution/pooling and optimizer settings, applied with the same\n",
    "# `update` call as above; every key listed here appears in the printed parameter table.\n",
    "model.params.update({\n",
    "    'task': ranking_task,\n",
    "    'embedding_output_dim': 100,\n",
    "    'embedding_trainable': True,\n",
    "    'num_blocks': 2,\n",
    "    'kernel_1d_count': 32,\n",
    "    'kernel_1d_size': 3,\n",
    "    'kernel_2d_count': [64, 64],\n",
    "    'kernel_2d_size': [3, 3],\n",
    "    'pool_2d_size': [[3, 3], [3, 3]],\n",
    "    'optimizer': 'adam',\n",
    "})\n",
    "\n",
    "# Build and compile the model, then print the resulting parameter table.\n",
    "model.build()\n",
    "model.compile()\n",
    "\n",
    "print(model.params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "text_left (InputLayer)          (None, 10)           0                                            \n",
      "__________________________________________________________________________________________________\n",
      "text_right (InputLayer)         (None, 100)          0                                            \n",
      "__________________________________________________________________________________________________\n",
      "embedding (Embedding)           multiple             1667400     text_left[0][0]                  \n",
      "                                                                 text_right[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "conv1d_1 (Conv1D)               (None, 10, 32)       9632        embedding[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "conv1d_2 (Conv1D)               (None, 100, 32)      9632        embedding[1][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "matching_layer_1 (MatchingLayer (None, 10, 100, 32)  0           conv1d_1[0][0]                   \n",
      "                                                                 conv1d_2[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_1 (Conv2D)               (None, 10, 100, 64)  18496       matching_layer_1[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "max_pooling2d_1 (MaxPooling2D)  (None, 3, 33, 64)    0           conv2d_1[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_2 (Conv2D)               (None, 3, 33, 64)    36928       max_pooling2d_1[0][0]            \n",
      "__________________________________________________________________________________________________\n",
      "max_pooling2d_2 (MaxPooling2D)  (None, 1, 11, 64)    0           conv2d_2[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "flatten_1 (Flatten)             (None, 704)          0           max_pooling2d_2[0][0]            \n",
      "__________________________________________________________________________________________________\n",
      "dropout_1 (Dropout)             (None, 704)          0           flatten_1[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "dense_1 (Dense)                 (None, 1)            705         dropout_1[0][0]                  \n",
      "==================================================================================================\n",
      "Total params: 1,742,793\n",
      "Trainable params: 1,742,793\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Print the layer-by-layer summary (output shapes and parameter counts) of the built model.\n",
    "model.backend.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the embedding matrix from `glove_embedding` using the term index of the fitted vocabulary unit.\n",
    "embedding_matrix = glove_embedding.build_matrix(\n",
    "    preprocessor.context['vocab_unit'].state['term_index']\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand the matrix built in the previous cell to the model.\n",
    "model.load_embedding_matrix(embedding_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unpack the complete processed test pack and hand it to the metrics callback.\n",
    "# NOTE(review): the callback output is labelled 'Validation' but is computed on the test\n",
    "# pack, and `dev_pack_processed` is never used in this notebook - confirm this is intended.\n",
    "test_x, test_y = test_pack_processed[:].unpack()\n",
    "evaluate = mz.callbacks.EvaluateAllMetrics(\n",
    "    model,\n",
    "    x=test_x,\n",
    "    y=test_y,\n",
    "    batch_size=len(test_y),  # batch size equals the number of test labels\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num batches: 102\n"
     ]
    }
   ],
   "source": [
    "# Pair-mode training generator over the processed training pack (batch_size=20).\n",
    "train_generator = mz.DataGenerator(\n",
    "    train_pack_processed,\n",
    "    batch_size=20,\n",
    "    mode='pair',\n",
    "    num_dup=2,\n",
    "    num_neg=1,\n",
    ")\n",
    "\n",
    "# The generator's length is the number of batches per epoch shown in the training log.\n",
    "print('num batches:', len(train_generator))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/30\n",
      "102/102 [==============================] - 6s 61ms/step - loss: 0.6123\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6165843661519929 - normalized_discounted_cumulative_gain@5(0.0): 0.6639951229938149 - mean_average_precision(0.0): 0.6255013171310638\n",
      "Epoch 2/30\n",
      "102/102 [==============================] - 11s 107ms/step - loss: 0.3213\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5795022399236163 - normalized_discounted_cumulative_gain@5(0.0): 0.6400033764936961 - mean_average_precision(0.0): 0.5927358064245385\n",
      "Epoch 3/30\n",
      "102/102 [==============================] - 12s 116ms/step - loss: 0.1556\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5734366624818106 - normalized_discounted_cumulative_gain@5(0.0): 0.6320835504730066 - mean_average_precision(0.0): 0.5873376891991933\n",
      "Epoch 4/30\n",
      "102/102 [==============================] - 12s 121ms/step - loss: 0.0966\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6047140052395196 - normalized_discounted_cumulative_gain@5(0.0): 0.6646231653619664 - mean_average_precision(0.0): 0.6177779169838045\n",
      "Epoch 5/30\n",
      "102/102 [==============================] - 12s 117ms/step - loss: 0.0688\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5804830097430296 - normalized_discounted_cumulative_gain@5(0.0): 0.6397107091735946 - mean_average_precision(0.0): 0.5861104257070671\n",
      "Epoch 6/30\n",
      "102/102 [==============================] - 12s 119ms/step - loss: 0.0741\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5733720686699197 - normalized_discounted_cumulative_gain@5(0.0): 0.6297166660287926 - mean_average_precision(0.0): 0.5793064781802681\n",
      "Epoch 7/30\n",
      "102/102 [==============================] - 12s 118ms/step - loss: 0.0545\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5612000374349604 - normalized_discounted_cumulative_gain@5(0.0): 0.6275035028226839 - mean_average_precision(0.0): 0.5812912862211518\n",
      "Epoch 8/30\n",
      "102/102 [==============================] - 13s 129ms/step - loss: 0.0351\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5695536121652267 - normalized_discounted_cumulative_gain@5(0.0): 0.6230819298416224 - mean_average_precision(0.0): 0.5829102115308369\n",
      "Epoch 9/30\n",
      "102/102 [==============================] - 13s 127ms/step - loss: 0.0546\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5747202712592626 - normalized_discounted_cumulative_gain@5(0.0): 0.6283361484936623 - mean_average_precision(0.0): 0.5858063874997032\n",
      "Epoch 10/30\n",
      "102/102 [==============================] - 12s 120ms/step - loss: 0.0363\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5608760792872554 - normalized_discounted_cumulative_gain@5(0.0): 0.6165261187651955 - mean_average_precision(0.0): 0.5804023750891116\n",
      "Epoch 11/30\n",
      "102/102 [==============================] - 12s 121ms/step - loss: 0.0316\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5526488015162425 - normalized_discounted_cumulative_gain@5(0.0): 0.6190121151432394 - mean_average_precision(0.0): 0.5724829930234496\n",
      "Epoch 12/30\n",
      "102/102 [==============================] - 12s 122ms/step - loss: 0.0286\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5574638704067251 - normalized_discounted_cumulative_gain@5(0.0): 0.6221425949954883 - mean_average_precision(0.0): 0.5790691069305163\n",
      "Epoch 13/30\n",
      "102/102 [==============================] - 14s 135ms/step - loss: 0.0232\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5530705667805896 - normalized_discounted_cumulative_gain@5(0.0): 0.6045382570104455 - mean_average_precision(0.0): 0.560135777928777\n",
      "Epoch 14/30\n",
      "102/102 [==============================] - 12s 122ms/step - loss: 0.0232\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5551975290816926 - normalized_discounted_cumulative_gain@5(0.0): 0.6199269060404192 - mean_average_precision(0.0): 0.5695024812501234\n",
      "Epoch 15/30\n",
      "102/102 [==============================] - 12s 120ms/step - loss: 0.0186\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5798942973253254 - normalized_discounted_cumulative_gain@5(0.0): 0.6326798983491299 - mean_average_precision(0.0): 0.5914300748272545\n",
      "Epoch 16/30\n",
      "102/102 [==============================] - 12s 120ms/step - loss: 0.0224\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5549313451763698 - normalized_discounted_cumulative_gain@5(0.0): 0.609066461479371 - mean_average_precision(0.0): 0.5696612552826874\n",
      "Epoch 17/30\n",
      "102/102 [==============================] - 14s 134ms/step - loss: 0.0320\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5362590661756249 - normalized_discounted_cumulative_gain@5(0.0): 0.5960587553325881 - mean_average_precision(0.0): 0.5529867856626977\n",
      "Epoch 18/30\n",
      "102/102 [==============================] - 13s 123ms/step - loss: 0.0231\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5709133081694716 - normalized_discounted_cumulative_gain@5(0.0): 0.6294306655133045 - mean_average_precision(0.0): 0.5833828714158493\n",
      "Epoch 19/30\n",
      "102/102 [==============================] - 12s 120ms/step - loss: 0.0210\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5408055010433294 - normalized_discounted_cumulative_gain@5(0.0): 0.6188813340014495 - mean_average_precision(0.0): 0.5639203433881982\n",
      "Epoch 20/30\n",
      "102/102 [==============================] - 11s 111ms/step - loss: 0.0099\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5857209979914645 - normalized_discounted_cumulative_gain@5(0.0): 0.6369323440919344 - mean_average_precision(0.0): 0.59214541005708\n",
      "Epoch 21/30\n",
      "102/102 [==============================] - 11s 108ms/step - loss: 0.0314\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5621754667046626 - normalized_discounted_cumulative_gain@5(0.0): 0.6250852206304124 - mean_average_precision(0.0): 0.5746827039867791\n",
      "Epoch 22/30\n",
      "102/102 [==============================] - 11s 108ms/step - loss: 0.0346\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5245291889962129 - normalized_discounted_cumulative_gain@5(0.0): 0.5957453960380752 - mean_average_precision(0.0): 0.5563949155969331\n",
      "Epoch 23/30\n",
      "102/102 [==============================] - 11s 104ms/step - loss: 0.0180\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5649023738386006 - normalized_discounted_cumulative_gain@5(0.0): 0.6281559244518438 - mean_average_precision(0.0): 0.5862423847100566\n",
      "Epoch 24/30\n",
      "102/102 [==============================] - 11s 110ms/step - loss: 0.0326\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5595631369578311 - normalized_discounted_cumulative_gain@5(0.0): 0.6216538205632761 - mean_average_precision(0.0): 0.5770444731020851\n",
      "Epoch 25/30\n",
      "102/102 [==============================] - 11s 107ms/step - loss: 0.0212\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5611397782325962 - normalized_discounted_cumulative_gain@5(0.0): 0.6275063453800273 - mean_average_precision(0.0): 0.5817784786583774\n",
      "Epoch 26/30\n",
      "102/102 [==============================] - 11s 110ms/step - loss: 0.0297\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5332907973060383 - normalized_discounted_cumulative_gain@5(0.0): 0.5988259655236242 - mean_average_precision(0.0): 0.5647126123155829\n",
      "Epoch 27/30\n",
      "102/102 [==============================] - 11s 108ms/step - loss: 0.0244\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5448159675702459 - normalized_discounted_cumulative_gain@5(0.0): 0.5945159262523649 - mean_average_precision(0.0): 0.5558586926930239\n",
      "Epoch 28/30\n",
      "102/102 [==============================] - 11s 108ms/step - loss: 0.0157\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.570165361029481 - normalized_discounted_cumulative_gain@5(0.0): 0.6316999513916105 - mean_average_precision(0.0): 0.5947381662168211\n",
      "Epoch 29/30\n",
      "102/102 [==============================] - 11s 107ms/step - loss: 0.0360\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.586639186111628 - normalized_discounted_cumulative_gain@5(0.0): 0.6352015410005084 - mean_average_precision(0.0): 0.596046073559104\n",
      "Epoch 30/30\n",
      "102/102 [==============================] - 11s 108ms/step - loss: 0.0149\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.57738099248289 - normalized_discounted_cumulative_gain@5(0.0): 0.63122082754378 - mean_average_precision(0.0): 0.5877516720454762\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# Cap the worker pool at the number of CPUs on this machine instead of always\n",
    "# spawning 30 processes; `os.cpu_count()` may return None, hence the `or 1`.\n",
    "num_workers = min(30, os.cpu_count() or 1)\n",
    "\n",
    "# Train for 30 epochs; the `evaluate` callback prints its metrics after every epoch.\n",
    "# `history` keeps the object returned by `fit_generator` for later inspection.\n",
    "history = model.fit_generator(\n",
    "    train_generator,\n",
    "    epochs=30,\n",
    "    callbacks=[evaluate],\n",
    "    workers=num_workers,\n",
    "    use_multiprocessing=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "Use `append_params_to_readme` (called in the next cell) to update README.md once you have found a better set of parameters.\n",
    "Before calling it, make sure you have deleted the matching (outdated) section of README.md."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Side effect: this updates README.md - read the note in the cell above before running.\n",
    "append_params_to_readme(model)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
